{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 NVIDIA Corporation. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
    "\n",
    "# NVTabular demo on Rossmann data\n",
    "\n",
    "## Overview\n",
    "\n",
    "NVTabular is a feature engineering and preprocessing library for tabular data, designed to quickly and easily manipulate terabyte-scale datasets used to train deep learning-based recommender systems. It provides a high-level abstraction to simplify code and accelerates computation on the GPU using the RAPIDS cuDF library."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Learning objectives\n",
    "\n",
    "This notebook demonstrates the steps for carrying out data preprocessing, transformation and loading with NVTabular on the Kaggle Rossmann [dataset](https://www.kaggle.com/c/rossmann-store-sales/overview). Rossmann operates over 3,000 drug stores in 7 European countries. Historical sales data for 1,115 Rossmann stores are provided. The task is to forecast the \"Sales\" column for the test set.\n",
    "\n",
    "The following example illustrates how to use NVTabular to preprocess and load tabular data for training neural networks in both PyTorch and TensorFlow. We'll use a [dataset built by FastAI](https://github.com/fastai/fastai/blob/master/courses/dl1/lesson3-rossman.ipynb) for solving the [Kaggle Rossmann Store Sales competition](https://www.kaggle.com/c/rossmann-store-sales). Some pandas preprocessing is required to build the appropriate feature set, so make sure to run [rossmann-store-sales-preproc.ipynb](./rossmann-store-sales-preproc.ipynb) before going through this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import nvtabular as nvt\n",
    "import glob"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparing our dataset\n",
    "Let's start by defining some a priori information about our data: its schema (which columns to use and what kinds of variables they represent) and the location of the files that hold the data. Throughout, we'll use UPPERCASE variable names for this sort of a priori information, which you might usually supply through command-line arguments or config files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = os.environ.get(\"OUTPUT_DATA_DIR\", \"./data\")\n",
    "\n",
    "CATEGORICAL_COLUMNS = [\n",
    "    'Store', 'DayOfWeek', 'Year', 'Month', 'Day', 'StateHoliday', 'CompetitionMonthsOpen',\n",
    "    'Promo2Weeks', 'StoreType', 'Assortment', 'PromoInterval', 'CompetitionOpenSinceYear', 'Promo2SinceYear',\n",
    "    'State', 'Week', 'Events', 'Promo_fw', 'Promo_bw', 'StateHoliday_fw', 'StateHoliday_bw',\n",
    "    'SchoolHoliday_fw', 'SchoolHoliday_bw'\n",
    "]\n",
    "\n",
    "CONTINUOUS_COLUMNS = [\n",
    "    'CompetitionDistance', 'Max_TemperatureC', 'Mean_TemperatureC', 'Min_TemperatureC',\n",
    "   'Max_Humidity', 'Mean_Humidity', 'Min_Humidity', 'Max_Wind_SpeedKm_h', \n",
    "   'Mean_Wind_SpeedKm_h', 'CloudCover', 'trend', 'trend_DE',\n",
    "   'AfterStateHoliday', 'BeforeStateHoliday', 'Promo', 'SchoolHoliday'\n",
    "]\n",
    "LABEL_COLUMNS = ['Sales']\n",
    "\n",
    "COLUMNS = CATEGORICAL_COLUMNS + CONTINUOUS_COLUMNS + LABEL_COLUMNS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What files are available to train on in our data directory?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ross_pre  test.csv  train.csv  valid.csv\r\n"
     ]
    }
   ],
   "source": [
    "! ls $DATA_DIR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`train.csv` and `valid.csv` seem like good candidates, so let's use those."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_PATH = os.path.join(DATA_DIR, 'train.csv')\n",
    "VALID_PATH = os.path.join(DATA_DIR, 'valid.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Workflows and Preprocessing\n",
    "A `Workflow` is used to represent the chains of feature engineering and preprocessing operations performed on a dataset, and is instantiated with a description of the dataset's schema so that it can keep track of how columns transform with each operation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# note that here, we want to perform a normalization transformation on the label\n",
    "# column. Since NVT doesn't support transforming label columns right now, we'll\n",
    "# pretend it's a regular continuous column during our feature engineering phase\n",
    "proc = nvt.Workflow(\n",
    "    cat_names=CATEGORICAL_COLUMNS,\n",
    "    cont_names=CONTINUOUS_COLUMNS,\n",
    "    label_name=LABEL_COLUMNS\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ops\n",
    "We add operations to a `Workflow` with the `add_(cat|cont)_feature` and `add_(cat|cont)_preprocess` methods, where the `cat` variants apply to categorical variables and the `cont` variants to continuous ones. When we're done adding ops, we call the `finalize` method to let the `Workflow` build a representation of its outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "proc.add_cont_feature(nvt.ops.FillMissing())\n",
    "proc.add_cont_preprocess(nvt.ops.LogOp(columns=LABEL_COLUMNS))\n",
    "proc.add_cont_preprocess(nvt.ops.Normalize())\n",
    "proc.add_cat_preprocess(nvt.ops.Categorify())\n",
    "proc.finalize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Datasets\n",
    "In general, the `Op`s in our `Workflow` require measurements of statistical properties of our data before they can be applied. For example, the `Normalize` op requires measurements of the dataset mean and standard deviation, and the `Categorify` op requires an accounting of all the categories a particular feature can take. However, we frequently need to measure these properties across datasets which are too large to fit into GPU memory (or CPU memory for that matter) at once.\n",
    "\n",
    "NVTabular solves this by providing the `Dataset` class, which breaks a set of parquet or csv files into a collection of `cudf.DataFrame` chunks that can fit in device memory. Under the hood, the data decomposition corresponds to the construction of a [dask_cudf.DataFrame](https://docs.rapids.ai/api/cudf/stable/dask-cudf.html) object. By representing our dataset as a lazily-evaluated [Dask](https://dask.org/) collection, we can handle the calculation of complex global statistics (and later, can also iterate over the partitions while feeding data into a neural network).\n",
    "\n"
   ]
  },
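  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why chunking is compatible with global statistics, consider the mean and standard deviation that `Normalize` needs: both can be recovered from running totals accumulated one chunk at a time. The following is a minimal NumPy sketch of that idea, not NVTabular's actual implementation; the function and variable names are hypothetical and carry a `_demo` suffix so that they do not clash with anything else in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def chunked_mean_std_demo(chunks):\n",
    "    # accumulate count, sum and sum of squares one chunk at a time,\n",
    "    # so the full column never has to be held in memory at once\n",
    "    count, total, total_sq = 0, 0.0, 0.0\n",
    "    for chunk in chunks:\n",
    "        count += chunk.size\n",
    "        total += chunk.sum()\n",
    "        total_sq += np.square(chunk).sum()\n",
    "    mean = total / count\n",
    "    std = np.sqrt(total_sq / count - mean ** 2)  # population standard deviation\n",
    "    return mean, std\n",
    "\n",
    "# illustrative data, split into three chunks\n",
    "values_demo = np.arange(10, dtype=np.float64)\n",
    "print(chunked_mean_std_demo(np.array_split(values_demo, 3)))"
   ]
  },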
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = nvt.Dataset(TRAIN_PATH)\n",
    "valid_dataset = nvt.Dataset(VALID_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have our datasets, we'll apply our `Workflow` to them and save the results out to parquet files for fast reading at train time. We'll also measure and record statistics on our training set using the `record_stats=True` kwarg so that our `Workflow` can use them at apply time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "PREPROCESS_DIR = os.path.join(DATA_DIR, 'ross_pre')\n",
    "PREPROCESS_DIR_TRAIN = os.path.join(PREPROCESS_DIR, 'train')\n",
    "PREPROCESS_DIR_VALID = os.path.join(PREPROCESS_DIR, 'valid')\n",
    "\n",
    "! rm -rf $PREPROCESS_DIR # remove previous trials\n",
    "! mkdir -p $PREPROCESS_DIR_TRAIN\n",
    "! mkdir -p $PREPROCESS_DIR_VALID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "proc.apply(train_dataset, record_stats=True, output_path=PREPROCESS_DIR_TRAIN, shuffle=nvt.io.Shuffle.PER_WORKER, out_files_per_proc=2)\n",
    "proc.apply(valid_dataset, record_stats=False, output_path=PREPROCESS_DIR_VALID, shuffle=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Finalize columns\n",
    "The FastAI workflow will use `nvtabular.loader.torch.TorchAsyncItr`, which will map a dataset to its corresponding PyTorch tensors. In order to make sure it runs correctly, we'll call the `create_final_cols` method to let the `Workflow` know to build the output dataset schema, and then we'll be sure to remove instances of the label column that got added to that schema when we performed processing on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "proc.create_final_cols()\n",
    "# using log op and normalize on sales column causes it to get added to\n",
    "# continuous columns_ctx, so we'll remove it here\n",
    "while True:\n",
    "    try:\n",
    "        proc.columns_ctx['final']['cols']['continuous'].remove(LABEL_COLUMNS[0])\n",
    "    except ValueError:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training a Network\n",
    "\n",
    "Now that our data is preprocessed and saved out, we can leverage `dataset`s to read through the preprocessed parquet files in an online fashion to train neural networks.\n",
    "\n",
    "We'll start by setting some universal hyperparameters for our model and optimizer. These settings will be shared across all of the frameworks that we explore below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "12% RMSPE is achievable using the Novograd optimizer, but we know of no Novograd implementation for TensorFlow that supports sparse gradients, and so we are not including that solution below. If you're interested in contributing to NVTabular, feel free to take on this challenge and submit a pull request if successful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "EMBEDDING_DROPOUT_RATE = 0.04\n",
    "DROPOUT_RATES = [0.001, 0.01]\n",
    "HIDDEN_DIMS = [1000, 500]\n",
    "BATCH_SIZE = 65536\n",
    "LEARNING_RATE = 0.001\n",
    "EPOCHS = 25\n",
    "\n",
    "# TODO: Calculate on the fly rather than recalling from previous analysis.\n",
    "MAX_SALES_IN_TRAINING_SET = 38722.0\n",
    "MAX_LOG_SALES_PREDICTION = 1.2 * np.log(MAX_SALES_IN_TRAINING_SET + 1.0)\n",
    "\n",
    "# It's possible to use defaults defined within NVTabular.\n",
    "EMBEDDING_TABLE_SHAPES = {\n",
    "    column: shape for column, shape in\n",
    "        nvt.ops.get_embedding_sizes(proc).items()\n",
    "}\n",
    "\n",
    "# Here, however, we will use fast.ai's rule for embedding sizes.\n",
    "for col in EMBEDDING_TABLE_SHAPES:\n",
    "    EMBEDDING_TABLE_SHAPES[col] = (EMBEDDING_TABLE_SHAPES[col][0], min(600, round(1.6 * EMBEDDING_TABLE_SHAPES[col][0] ** 0.56)))\n",
    "\n",
    "TRAIN_PATHS = sorted(glob.glob(os.path.join(PREPROCESS_DIR_TRAIN, '*.parquet')))\n",
    "VALID_PATHS = sorted(glob.glob(os.path.join(PREPROCESS_DIR_VALID, '*.parquet')))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following shows the cardinality of each categorical variable along with its associated embedding size. Each entry is of the form `(cardinality, embedding_size)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Assortment': (4, 3),\n",
       " 'CompetitionMonthsOpen': (26, 10),\n",
       " 'CompetitionOpenSinceYear': (24, 9),\n",
       " 'Day': (32, 11),\n",
       " 'DayOfWeek': (8, 5),\n",
       " 'Events': (22, 9),\n",
       " 'Month': (13, 7),\n",
       " 'Promo2SinceYear': (9, 5),\n",
       " 'Promo2Weeks': (27, 10),\n",
       " 'PromoInterval': (4, 3),\n",
       " 'Promo_bw': (7, 5),\n",
       " 'Promo_fw': (7, 5),\n",
       " 'SchoolHoliday_bw': (9, 5),\n",
       " 'SchoolHoliday_fw': (9, 5),\n",
       " 'State': (13, 7),\n",
       " 'StateHoliday': (3, 3),\n",
       " 'StateHoliday_bw': (4, 3),\n",
       " 'StateHoliday_fw': (4, 3),\n",
       " 'Store': (1116, 81),\n",
       " 'StoreType': (5, 4),\n",
       " 'Week': (53, 15),\n",
       " 'Year': (4, 3)}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "EMBEDDING_TABLE_SHAPES"
   ]
  },
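  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the rule used above, the following minimal sketch recomputes a few embedding sizes in plain Python. The helper name `fastai_embedding_dim` is ours, introduced only for illustration; the formula itself is the one from the hyperparameter cell, `min(600, round(1.6 * cardinality ** 0.56))`. For the cardinalities listed in the table above, it should reproduce the corresponding embedding sizes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fastai_embedding_dim(cardinality, max_dim=600):\n",
    "    # embedding width from the rule used above (helper name is illustrative)\n",
    "    return min(max_dim, round(1.6 * cardinality ** 0.56))\n",
    "\n",
    "# cardinalities taken from the table above\n",
    "for name, cardinality in [('Store', 1116), ('Week', 53), ('Assortment', 4)]:\n",
    "    print(name, (cardinality, fastai_embedding_dim(cardinality)))"
   ]
  },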
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choose a Framework\n",
    "\n",
    "We're now ready to move on to framework-specific code.\n",
    "\n",
    "**The code for each framework can be run independently of the others, so feel free to skip to your framework of choice.**\n",
    "\n",
    " * [TensorFlow](#TensorFlow)\n",
    " * [PyTorch](#PyTorch)\n",
    " * [fast.ai](#fast.ai)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TensorFlow\n",
    "<a id=\"TensorFlow\"></a>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TensorFlow: Preparing Datasets\n",
    "\n",
    "`KerasSequenceLoader` wraps a lightweight iterator around a `dataset` object to handle chunking, shuffling, and application of any workflows (which can be applied online as a preprocessing step). For column names, you can use either a list of string names or a list of TensorFlow `feature_columns` that will be used to feed the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# we can control how much memory to give tensorflow with this environment variable\n",
    "# IMPORTANT: make sure you do this before you initialize TF's runtime, otherwise\n",
    "# it's too late and TF will have claimed all free GPU memory\n",
    "os.environ['TF_MEMORY_ALLOCATION'] = \"8192\" # explicit MB\n",
    "os.environ['TF_MEMORY_ALLOCATION'] = \"0.5\" # fraction of free memory\n",
    "from nvtabular.loader.tensorflow import KerasSequenceLoader, KerasSequenceValidater\n",
    "\n",
    "# cheap wrapper to keep things some semblance of neat\n",
    "def make_categorical_embedding_column(name, dictionary_size, embedding_dim):\n",
    "    return tf.feature_column.embedding_column(\n",
    "        tf.feature_column.categorical_column_with_identity(name, dictionary_size),\n",
    "        embedding_dim\n",
    "    )\n",
    "\n",
    "# instantiate our columns\n",
    "categorical_columns = [\n",
    "    make_categorical_embedding_column(name, *EMBEDDING_TABLE_SHAPES[name]) for\n",
    "        name in CATEGORICAL_COLUMNS\n",
    "]\n",
    "continuous_columns = [\n",
    "    tf.feature_column.numeric_column(name, (1,)) for name in CONTINUOUS_COLUMNS\n",
    "]\n",
    "\n",
    "# feed them to our datasets\n",
    "train_dataset_tf = KerasSequenceLoader(\n",
    "    TRAIN_PATHS, # you could also use a glob pattern\n",
    "    feature_columns=categorical_columns+continuous_columns,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    label_names=LABEL_COLUMNS,\n",
    "    shuffle=True,\n",
    "    buffer_size=0.06 # amount of data, as a fraction of GPU memory, to load at once\n",
    ")\n",
    "\n",
    "valid_dataset_tf = KerasSequenceLoader(\n",
    "    VALID_PATHS, # you could also use a glob pattern\n",
    "    feature_columns=categorical_columns+continuous_columns,\n",
    "    batch_size=BATCH_SIZE*4,\n",
    "    label_names=LABEL_COLUMNS,\n",
    "    shuffle=False,\n",
    "    buffer_size=0.06 # amount of data, as a fraction of GPU memory, to load at once\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TensorFlow: Defining a Model\n",
    "\n",
    "Using Keras, we can define the layers of our model and their parameters explicitly. Here, for the sake of consistency, we'll mimic fast.ai's [TabularModel](https://docs.fast.ai/tabular.learner.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DenseFeatures layer needs a dictionary of {feature_name: input}\n",
    "categorical_inputs = {}\n",
    "for column_name in CATEGORICAL_COLUMNS:\n",
    "    categorical_inputs[column_name] = tf.keras.Input(name=column_name, shape=(1,), dtype=tf.int64)\n",
    "categorical_embedding_layer = tf.keras.layers.DenseFeatures(categorical_columns)\n",
    "categorical_x = categorical_embedding_layer(categorical_inputs)\n",
    "categorical_x = tf.keras.layers.Dropout(EMBEDDING_DROPOUT_RATE)(categorical_x)\n",
    "\n",
    "# Just concatenating continuous, so can use a list\n",
    "continuous_inputs = []\n",
    "for column_name in CONTINUOUS_COLUMNS:\n",
    "    continuous_inputs.append(tf.keras.Input(name=column_name, shape=(1,), dtype=tf.float32))\n",
    "continuous_embedding_layer = tf.keras.layers.Concatenate(axis=1)\n",
    "continuous_x = continuous_embedding_layer(continuous_inputs)\n",
    "continuous_x = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1)(continuous_x)\n",
    "\n",
    "# concatenate and build MLP\n",
    "x = tf.keras.layers.Concatenate(axis=1)([categorical_x, continuous_x])\n",
    "for dim, dropout_rate in zip(HIDDEN_DIMS, DROPOUT_RATES):\n",
    "    x = tf.keras.layers.Dense(dim, activation='relu')(x)\n",
    "    x = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1)(x)\n",
    "    x = tf.keras.layers.Dropout(dropout_rate)(x)\n",
    "x = tf.keras.layers.Dense(1, activation='linear')(x)\n",
    "\n",
    "# TODO: Initialize model weights to fix saturation issues.\n",
    "# For now, we'll just scale the output of our model directly before\n",
    "# hitting the sigmoid.\n",
    "x = 0.1 * x\n",
    "\n",
    "x = MAX_LOG_SALES_PREDICTION * tf.keras.activations.sigmoid(x)\n",
    "\n",
    "# combine all our inputs into a single list\n",
    "# (note that you can still use .fit, .predict, etc. on a dict\n",
    "# that maps input tensor names to input values)\n",
    "inputs = list(categorical_inputs.values()) + continuous_inputs\n",
    "tf_model = tf.keras.Model(inputs=inputs, outputs=x)"
   ]
  },
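  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last two steps of the model bound its prediction: scaling by 0.1 and applying a sigmoid maps any raw output into the open interval (0, 1), and multiplying by `MAX_LOG_SALES_PREDICTION` stretches that to (0, `MAX_LOG_SALES_PREDICTION`). Below is a minimal NumPy sketch of the same arithmetic. It is independent of the Keras graph, and its `_demo` names are hypothetical, chosen so that nothing defined above is overwritten."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# 38722.0 is the MAX_SALES_IN_TRAINING_SET value set earlier in this notebook\n",
    "max_log_sales_demo = 1.2 * np.log(38722.0 + 1.0)\n",
    "\n",
    "def bounded_log_sales_demo(raw_output, scale=0.1):\n",
    "    # sigmoid(scale * raw_output), stretched to (0, max_log_sales_demo)\n",
    "    return max_log_sales_demo / (1.0 + np.exp(-scale * raw_output))\n",
    "\n",
    "# very negative, zero, and very positive raw outputs all stay inside the bound\n",
    "raw_demo = np.array([-50.0, 0.0, 50.0])\n",
    "print(max_log_sales_demo, bounded_log_sales_demo(raw_demo))"
   ]
  },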
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TensorFlow: Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rmspe_tf(y_true, y_pred):\n",
    "    # map back into \"true\" space by undoing transform\n",
    "    y_true = tf.exp(y_true) - 1\n",
    "    y_pred = tf.exp(y_pred) - 1\n",
    "\n",
    "    percent_error = (y_true - y_pred) / y_true\n",
    "    return tf.sqrt(tf.reduce_mean(percent_error**2))"
   ]
  },
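  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the metric concrete, here is a minimal NumPy sketch of the same calculation on hand-picked numbers. It assumes, as `rmspe_tf` does, that labels and predictions are stored as log(sales + 1); the function name `rmspe_numpy_demo` and the sample values are illustrative only. Each prediction is off by 10% of the true value, so the result should be close to 0.1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def rmspe_numpy_demo(y_true_log, y_pred_log):\n",
    "    # undo the log(x + 1) transform, then take the root mean squared percentage error\n",
    "    y_true = np.expm1(y_true_log)\n",
    "    y_pred = np.expm1(y_pred_log)\n",
    "    percent_error = (y_true - y_pred) / y_true\n",
    "    return np.sqrt(np.mean(percent_error ** 2))\n",
    "\n",
    "# true sales of 100 and 200, predicted as 110 and 180\n",
    "y_true_demo = np.log1p(np.array([100.0, 200.0]))\n",
    "y_pred_demo = np.log1p(np.array([110.0, 180.0]))\n",
    "print(rmspe_numpy_demo(y_true_demo, y_pred_demo))"
   ]
  },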
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)\n",
    "tf_model.compile(optimizer, 'mse', metrics=[rmspe_tf])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/25\n",
      "13/13 [==============================] - 4s 278ms/step - loss: 6.0125 - rmspe_tf: 0.8914 - val_loss: 6.1718 - val_rmspe_tf: 0.9068\n",
      "Epoch 2/25\n",
      "13/13 [==============================] - 6s 470ms/step - loss: 5.2308 - rmspe_tf: 0.8905 - val_loss: 4.6752 - val_rmspe_tf: 0.8793\n",
      "Epoch 3/25\n",
      "13/13 [==============================] - 18s 1s/step - loss: 4.5764 - rmspe_tf: 0.8769 - val_loss: 4.0147 - val_rmspe_tf: 0.8604\n",
      "Epoch 4/25\n",
      "13/13 [==============================] - 16s 1s/step - loss: 3.7594 - rmspe_tf: 0.8504 - val_loss: 3.1879 - val_rmspe_tf: 0.8263\n",
      "Epoch 5/25\n",
      "13/13 [==============================] - 6s 494ms/step - loss: 2.7753 - rmspe_tf: 0.8028 - val_loss: 2.1554 - val_rmspe_tf: 0.7611\n",
      "Epoch 6/25\n",
      "13/13 [==============================] - 5s 411ms/step - loss: 1.7590 - rmspe_tf: 0.7221 - val_loss: 1.1834 - val_rmspe_tf: 0.6504\n",
      "Epoch 7/25\n",
      "13/13 [==============================] - 4s 275ms/step - loss: 0.9113 - rmspe_tf: 0.5963 - val_loss: 0.4885 - val_rmspe_tf: 0.4849\n",
      "Epoch 8/25\n",
      "13/13 [==============================] - 5s 394ms/step - loss: 0.3716 - rmspe_tf: 0.4380 - val_loss: 0.1876 - val_rmspe_tf: 0.3339\n",
      "Epoch 9/25\n",
      "13/13 [==============================] - 5s 397ms/step - loss: 0.1259 - rmspe_tf: 0.2894 - val_loss: 0.0627 - val_rmspe_tf: 0.2305\n",
      "Epoch 10/25\n",
      "13/13 [==============================] - 4s 294ms/step - loss: 0.0510 - rmspe_tf: 0.2263 - val_loss: 0.0487 - val_rmspe_tf: 0.2156\n",
      "Epoch 11/25\n",
      "13/13 [==============================] - 3s 206ms/step - loss: 0.0371 - rmspe_tf: 0.2373 - val_loss: 0.0447 - val_rmspe_tf: 0.2192\n",
      "Epoch 12/25\n",
      "13/13 [==============================] - 3s 199ms/step - loss: 0.0350 - rmspe_tf: 0.2269 - val_loss: 0.0513 - val_rmspe_tf: 0.2567\n",
      "Epoch 13/25\n",
      "13/13 [==============================] - 2s 189ms/step - loss: 0.0330 - rmspe_tf: 0.2072 - val_loss: 0.0465 - val_rmspe_tf: 0.2418\n",
      "Epoch 14/25\n",
      "13/13 [==============================] - 2s 177ms/step - loss: 0.0313 - rmspe_tf: 0.2094 - val_loss: 0.0410 - val_rmspe_tf: 0.2078\n",
      "Epoch 15/25\n",
      "13/13 [==============================] - 2s 173ms/step - loss: 0.0298 - rmspe_tf: 0.2043 - val_loss: 0.0430 - val_rmspe_tf: 0.2290\n",
      "Epoch 16/25\n",
      "13/13 [==============================] - 3s 197ms/step - loss: 0.0289 - rmspe_tf: 0.2112 - val_loss: 0.0417 - val_rmspe_tf: 0.2251\n",
      "Epoch 17/25\n",
      "13/13 [==============================] - 2s 160ms/step - loss: 0.0281 - rmspe_tf: 0.1864 - val_loss: 0.0481 - val_rmspe_tf: 0.2554\n",
      "Epoch 18/25\n",
      "13/13 [==============================] - 2s 160ms/step - loss: 0.0273 - rmspe_tf: 0.1959 - val_loss: 0.0393 - val_rmspe_tf: 0.2190\n",
      "Epoch 19/25\n",
      "13/13 [==============================] - 2s 167ms/step - loss: 0.0262 - rmspe_tf: 0.1923 - val_loss: 0.0464 - val_rmspe_tf: 0.2512\n",
      "Epoch 20/25\n",
      "13/13 [==============================] - 2s 181ms/step - loss: 0.0260 - rmspe_tf: 0.1978 - val_loss: 0.0472 - val_rmspe_tf: 0.2549\n",
      "Epoch 21/25\n",
      "13/13 [==============================] - 2s 162ms/step - loss: 0.0255 - rmspe_tf: 0.1891 - val_loss: 0.0449 - val_rmspe_tf: 0.2469\n",
      "Epoch 22/25\n",
      "13/13 [==============================] - 2s 170ms/step - loss: 0.0253 - rmspe_tf: 0.1858 - val_loss: 0.0371 - val_rmspe_tf: 0.2104\n",
      "Epoch 23/25\n",
      "13/13 [==============================] - 2s 160ms/step - loss: 0.0252 - rmspe_tf: 0.1948 - val_loss: 0.0534 - val_rmspe_tf: 0.2788\n",
      "Epoch 24/25\n",
      "13/13 [==============================] - 2s 171ms/step - loss: 0.0267 - rmspe_tf: 0.1804 - val_loss: 0.0382 - val_rmspe_tf: 0.1884\n",
      "Epoch 25/25\n",
      "13/13 [==============================] - 2s 155ms/step - loss: 0.0252 - rmspe_tf: 0.1940 - val_loss: 0.0344 - val_rmspe_tf: 0.1996\n",
      "CPU times: user 2min 27s, sys: 27.2 s, total: 2min 54s\n",
      "Wall time: 2min 3s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)\n",
    "tf_model.compile(optimizer, 'mse', metrics=[rmspe_tf])\n",
    "\n",
    "validation_callback = KerasSequenceValidater(valid_dataset_tf)\n",
    "history = tf_model.fit(\n",
    "    train_dataset_tf,\n",
    "    callbacks=[validation_callback],\n",
    "    epochs=EPOCHS,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch<a id=\"PyTorch\"></a>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyTorch: Preparing Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from nvtabular.loader.torch import TorchAsyncItr, DLDataLoader\n",
    "from nvtabular.framework_utils.torch.models import Model\n",
    "from nvtabular.framework_utils.torch.utils import process_epoch\n",
    "\n",
    "# TorchAsyncItr returns a single batch of x_cat, x_cont, y, so the collate function just passes it through.\n",
    "collate_fn = lambda x: x\n",
    "\n",
    "train_dataset = TorchAsyncItr(nvt.Dataset(TRAIN_PATHS), batch_size=BATCH_SIZE, cats=CATEGORICAL_COLUMNS, conts=CONTINUOUS_COLUMNS, labels=LABEL_COLUMNS)\n",
    "train_loader = DLDataLoader(train_dataset, batch_size=None, collate_fn=collate_fn, pin_memory=False, num_workers=0)\n",
    "\n",
    "valid_dataset = TorchAsyncItr(nvt.Dataset(VALID_PATHS), batch_size=BATCH_SIZE, cats=CATEGORICAL_COLUMNS, conts=CONTINUOUS_COLUMNS, labels=LABEL_COLUMNS)\n",
    "valid_loader = DLDataLoader(valid_dataset, batch_size=None, collate_fn=collate_fn, pin_memory=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyTorch: Defining a Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model(\n",
    "    embedding_table_shapes=EMBEDDING_TABLE_SHAPES,\n",
    "    num_continuous=len(CONTINUOUS_COLUMNS),\n",
    "    emb_dropout=EMBEDDING_DROPOUT_RATE,\n",
    "    layer_hidden_dims=HIDDEN_DIMS,\n",
    "    layer_dropout_rates=DROPOUT_RATES,\n",
    "    max_output=MAX_LOG_SALES_PREDICTION\n",
    ").to('cuda')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyTorch: Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rmspe_func(y_pred, y):\n",
    "    \"Return y_pred and y to non-log space and compute RMSPE\"\n",
    "    y_pred, y = torch.exp(y_pred) - 1, torch.exp(y) - 1\n",
    "    pct_var = (y_pred - y) / y\n",
    "    return (pct_var**2).mean().pow(0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00. Train loss: 7.6541. Train RMSPE: 2.3277. Valid loss: 3.8562. Valid RMSPE: 0.8409.\n",
      "Epoch 01. Train loss: 3.8174. Train RMSPE: 0.8135. Valid loss: 2.6833. Valid RMSPE: 0.7823.\n",
      "Epoch 02. Train loss: 2.4370. Train RMSPE: 0.7532. Valid loss: 1.6169. Valid RMSPE: 0.6843.\n",
      "Epoch 03. Train loss: 1.1605. Train RMSPE: 0.6109. Valid loss: 0.6108. Valid RMSPE: 0.4976.\n",
      "Epoch 04. Train loss: 0.3830. Train RMSPE: 0.4844. Valid loss: 0.2226. Valid RMSPE: 0.4973.\n",
      "Epoch 05. Train loss: 0.2199. Train RMSPE: 0.6033. Valid loss: 0.1974. Valid RMSPE: 0.5660.\n",
      "Epoch 06. Train loss: 0.1954. Train RMSPE: 0.5941. Valid loss: 0.1653. Valid RMSPE: 0.4996.\n",
      "Epoch 07. Train loss: 0.1642. Train RMSPE: 0.5099. Valid loss: 0.1389. Valid RMSPE: 0.4282.\n",
      "Epoch 08. Train loss: 0.1489. Train RMSPE: 0.4298. Valid loss: 0.1250. Valid RMSPE: 0.3837.\n",
      "Epoch 09. Train loss: 0.1369. Train RMSPE: 0.4224. Valid loss: 0.1160. Valid RMSPE: 0.3943.\n",
      "Epoch 10. Train loss: 0.1269. Train RMSPE: 0.4047. Valid loss: 0.1068. Valid RMSPE: 0.3653.\n",
      "Epoch 11. Train loss: 0.1180. Train RMSPE: 0.4122. Valid loss: 0.1038. Valid RMSPE: 0.3748.\n",
      "Epoch 12. Train loss: 0.1123. Train RMSPE: 0.3866. Valid loss: 0.0939. Valid RMSPE: 0.3489.\n",
      "Epoch 13. Train loss: 0.1067. Train RMSPE: 0.3759. Valid loss: 0.0976. Valid RMSPE: 0.3089.\n",
      "Epoch 14. Train loss: 0.1104. Train RMSPE: 0.4015. Valid loss: 0.0850. Valid RMSPE: 0.3047.\n",
      "Epoch 15. Train loss: 0.1000. Train RMSPE: 0.3564. Valid loss: 0.0877. Valid RMSPE: 0.3471.\n",
      "Epoch 16. Train loss: 0.0950. Train RMSPE: 0.3490. Valid loss: 0.0781. Valid RMSPE: 0.3194.\n",
      "Epoch 17. Train loss: 0.0938. Train RMSPE: 0.3709. Valid loss: 0.0781. Valid RMSPE: 0.2871.\n",
      "Epoch 18. Train loss: 0.0985. Train RMSPE: 0.3634. Valid loss: 0.1720. Valid RMSPE: 0.3634.\n",
      "Epoch 19. Train loss: 0.1039. Train RMSPE: 0.3729. Valid loss: 0.0805. Valid RMSPE: 0.3397.\n",
      "Epoch 20. Train loss: 0.0817. Train RMSPE: 0.3229. Valid loss: 0.0713. Valid RMSPE: 0.2713.\n",
      "Epoch 21. Train loss: 0.0779. Train RMSPE: 0.3260. Valid loss: 0.0633. Valid RMSPE: 0.2670.\n",
      "Epoch 22. Train loss: 0.0759. Train RMSPE: 0.3280. Valid loss: 0.0610. Valid RMSPE: 0.2712.\n",
      "Epoch 23. Train loss: 0.0724. Train RMSPE: 0.3059. Valid loss: 0.0588. Valid RMSPE: 0.2574.\n",
      "Epoch 24. Train loss: 0.0710. Train RMSPE: 0.3046. Valid loss: 0.0577. Valid RMSPE: 0.2619.\n",
      "CPU times: user 40.8 s, sys: 14.6 s, total: 55.4 s\n",
      "Wall time: 44.9 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    train_loss, y_pred, y = process_epoch(train_loader, model, train=True, optimizer=optimizer)\n",
    "    train_rmspe = rmspe_func(y_pred, y)\n",
    "    valid_loss, y_pred, y = process_epoch(valid_loader, model, train=False)\n",
    "    valid_rmspe = rmspe_func(y_pred, y)\n",
    "    print(f'Epoch {epoch:02d}. Train loss: {train_loss:.4f}. Train RMSPE: {train_rmspe:.4f}. Valid loss: {valid_loss:.4f}. Valid RMSPE: {valid_rmspe:.4f}.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fast.ai<a id=\"fast.ai\"></a>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### fast.ai: Preparing Datasets\n",
    "\n",
    "`TorchAsyncItr` maps a symbolic dataset object to `cat_features`, `cont_features`, `labels` PyTorch tensors by iterating through the dataset and concatenating the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.0.16'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import fastai\n",
    "\n",
    "fastai.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from nvtabular.loader.torch import TorchAsyncItr, DLDataLoader\n",
    "from fastai.tabular.data import TabularDataLoaders\n",
    "from fastai.tabular.model import TabularModel\n",
    "from fastai.basics import Learner\n",
    "from fastai.basics import MSELossFlat\n",
    "from fastai.callback.progress import ProgressCallback\n",
    "\n",
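    "def make_batched_dataloader(paths, columns, batch_size):\n",
    "    # NOTE: `columns` is currently unused; the columns to load are taken\n",
    "    # from the CATEGORICAL_COLUMNS, CONTINUOUS_COLUMNS and LABEL_COLUMNS globals.\n",
    "    dataset = nvt.Dataset(paths)\n",
    "    ds_batch_sets = TorchAsyncItr(\n",
    "        dataset,\n",
    "        batch_size=batch_size,\n",
    "        cats=CATEGORICAL_COLUMNS,\n",
    "        conts=CONTINUOUS_COLUMNS,\n",
    "        labels=LABEL_COLUMNS,\n",
    "    )\n",
    "    # TorchAsyncItr already yields whole batches, so batch_size=None turns off\n",
    "    # the DataLoader's own batching.\n",
    "    return DLDataLoader(\n",
    "        ds_batch_sets,\n",
    "        batch_size=None,\n",
    "        pin_memory=False,\n",
    "        num_workers=0,\n",
    "    )\n",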
    "\n",
    "train_dataset_pt = make_batched_dataloader(TRAIN_PATHS, COLUMNS, BATCH_SIZE)\n",
    "valid_dataset_pt = make_batched_dataloader(VALID_PATHS, COLUMNS, BATCH_SIZE*4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "databunch = TabularDataLoaders(train_dataset_pt, valid_dataset_pt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### fast.ai: Defining a Model\n",
    "\n",
    "Next we'll need to define the inputs that will feed our model and build an architecture on top of them. For now, we'll just stick to a simple MLP model.\n",
    "\n",
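    "fastai's `TabularModel` builds the MLP for us; we only specify its high-level characteristics: the embedding sizes, the number of continuous features, the hidden layer sizes, the dropout rates, and the output range."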
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "pt_model = TabularModel(\n",
    "    emb_szs=list(EMBEDDING_TABLE_SHAPES.values()),\n",
    "    n_cont=len(CONTINUOUS_COLUMNS),\n",
    "    out_sz=1,\n",
    "    layers=HIDDEN_DIMS,\n",
    "    ps=DROPOUT_RATES,\n",
    "    use_bn=True,\n",
    "    embed_p=EMBEDDING_DROPOUT_RATE,\n",
    "    y_range=torch.tensor([0.0, MAX_LOG_SALES_PREDICTION]),\n",
    ").cuda()"
   ]
  },
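  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`y_range` bounds the model's prediction. To the best of our understanding, fastai applies it by passing the final linear output through a sigmoid rescaled to the given interval. The sketch below reproduces that arithmetic in plain Python; the bounds `0.0` and `10.0` are placeholders for illustration, not the notebook's `MAX_LOG_SALES_PREDICTION`, and `sigmoid_range_py` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def sigmoid_range_py(x, low, high):\n",
    "    # Squash an unbounded value x into the open interval (low, high).\n",
    "    return 1.0 / (1.0 + math.exp(-x)) * (high - low) + low\n",
    "\n",
    "low, high = 0.0, 10.0  # placeholder bounds, not the notebook's values\n",
    "outputs = [sigmoid_range_py(x, low, high) for x in (-20.0, 0.0, 20.0)]\n",
    "print(outputs)  # near low, the midpoint, near high\n",
    "```\n",
    "\n",
    "Because the sigmoid only approaches its bounds asymptotically, the upper bound is usually chosen somewhat above the largest target value so that the model can still reach it."
   ]
  },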
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### fast.ai: Training"
   ]
  },
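  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `exp_rmspe` metric in the training cell below applies `exp(x) - 1` to predictions and targets before computing the root mean squared percentage error, which suggests that the labels are stored as `log(1 + Sales)`. A minimal pure-Python sketch of the same arithmetic follows; the sales figures and the name `exp_rmspe_py` are made up for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def exp_rmspe_py(preds, targets):\n",
    "    # RMSPE computed after inverting a log(1 + x) transform on both inputs.\n",
    "    pct_errors = []\n",
    "    for p, t in zip(preds, targets):\n",
    "        sales_pred = math.exp(p) - 1\n",
    "        sales_true = math.exp(t) - 1\n",
    "        pct_errors.append((sales_true - sales_pred) / sales_true)\n",
    "    return math.sqrt(sum(e * e for e in pct_errors) / len(pct_errors))\n",
    "\n",
    "# Hypothetical sales figures; each prediction is off by 10%.\n",
    "true_sales = [5000.0, 8000.0]\n",
    "pred_sales = [5500.0, 7200.0]\n",
    "score = exp_rmspe_py(\n",
    "    [math.log1p(s) for s in pred_sales],\n",
    "    [math.log1p(s) for s in true_sales],\n",
    ")\n",
    "print(round(score, 4))\n",
    "```\n",
    "\n",
    "Both hypothetical predictions miss by 10%, so the score comes out at about 0.1. A target of exactly zero sales would cause a division by zero, so such rows need to be excluded or handled before the metric is meaningful."
   ]
  },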
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>exp_rmspe</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>exp_rmspe</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.215761</td>\n",
       "      <td>0.164155</td>\n",
       "      <td>0.331342</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.127483</td>\n",
       "      <td>0.071673</td>\n",
       "      <td>0.246229</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.088330</td>\n",
       "      <td>0.032615</td>\n",
       "      <td>0.181555</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.066506</td>\n",
       "      <td>0.025614</td>\n",
       "      <td>0.167483</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.052737</td>\n",
       "      <td>0.023265</td>\n",
       "      <td>0.160165</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.043378</td>\n",
       "      <td>0.022198</td>\n",
       "      <td>0.157680</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.036673</td>\n",
       "      <td>0.021151</td>\n",
       "      <td>0.154265</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.031675</td>\n",
       "      <td>0.020010</td>\n",
       "      <td>0.150160</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.027837</td>\n",
       "      <td>0.019383</td>\n",
       "      <td>0.147319</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.024828</td>\n",
       "      <td>0.018521</td>\n",
       "      <td>0.144329</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.022422</td>\n",
       "      <td>0.017908</td>\n",
       "      <td>0.142639</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.020478</td>\n",
       "      <td>0.017056</td>\n",
       "      <td>0.137023</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.018876</td>\n",
       "      <td>0.016584</td>\n",
       "      <td>0.135137</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.017551</td>\n",
       "      <td>0.016209</td>\n",
       "      <td>0.133011</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.016442</td>\n",
       "      <td>0.015912</td>\n",
       "      <td>0.132330</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>0.015490</td>\n",
       "      <td>0.015565</td>\n",
       "      <td>0.130401</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>0.014674</td>\n",
       "      <td>0.015242</td>\n",
       "      <td>0.128135</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>0.013970</td>\n",
       "      <td>0.015026</td>\n",
       "      <td>0.128424</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>0.013368</td>\n",
       "      <td>0.014852</td>\n",
       "      <td>0.126099</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>0.012857</td>\n",
       "      <td>0.014560</td>\n",
       "      <td>0.124652</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.012413</td>\n",
       "      <td>0.014460</td>\n",
       "      <td>0.125070</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>0.012033</td>\n",
       "      <td>0.014333</td>\n",
       "      <td>0.122164</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>0.011745</td>\n",
       "      <td>0.013983</td>\n",
       "      <td>0.121915</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>0.011450</td>\n",
       "      <td>0.014362</td>\n",
       "      <td>0.127220</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>0.011145</td>\n",
       "      <td>0.014084</td>\n",
       "      <td>0.122282</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from fastai.torch_core import flatten_check\n",
    "\n",
    "def exp_rmspe(pred, targ):\n",
    "    \"Exp RMSPE between `pred` and `targ`.\"\n",
    "    pred, targ = flatten_check(pred, targ)\n",
    "    # exp(x) - 1 inverts a log(1 + x) transform, mapping values back to the sales scale.\n",
    "    pred, targ = torch.exp(pred) - 1, torch.exp(targ) - 1\n",
    "    pct_var = (targ - pred) / targ\n",
    "    return torch.sqrt((pct_var**2).mean())\n",
    "\n",
    "loss_func = MSELossFlat()\n",
    "learner = Learner(databunch, pt_model, loss_func=loss_func, metrics=[exp_rmspe], cbs=ProgressCallback())\n",
    "learner.fit(EPOCHS, LEARNING_RATE)"
   ]
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "rapids",
   "language": "python",
   "name": "myenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
